{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "n_test_batches = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 11 - プライバシーに配慮したディープラーニングで分類問題を解く\n",
    "\n",
    "\n",
    "\n",
    "## データの機密性は重要です。と同時に、モデルの機密性も重要です\n",
    "\n",
    "データは機械学習の肝です。組織はデータを作成したり集めたりすることで、独自のモデルをトレーニングすることができ、それをサービス(MLaaS)として外部に公開できます。自分たちでモデルのトレーニングを行えない組織は、公開されたサービスを使って自分たちのデータを推論することができます。\n",
    "\n",
    "しかし、クラウド上のモデルにはプライバシーや知財の問題があります。外部の組織が使おうと思うと、推論したいデータをクラウドにアップロードするか、もしくはモデルをダウンロードする必要があります。入力データのアップロードにはプライバシーの問題がありますし、モデルのダウンロードはモデル所有者が知財を失ってしまうリスクがあります。\n",
    "\n",
    "\n",
    "## 暗号化されたデータを使ってのコンピューテーション\n",
    "\n",
    "こういった状況下における潜在的な解決策は、データとモデルの両方を暗号化し、お互いに知財を非公開とする事です。それを可能にする暗号化手法はいくつか存在します。その中でも、Secure Multi-Party Computation (SMPC)とHomomorphic Encryption (FHE/SHE) 、それに Functional Encryption (FE)はよく知られています。ここでは\"Secure Multi-Party Computation\" ([introduced in detail here in tutorial 5](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%205%20-%20Intro%20to%20Encrypted%20Programs.ipynb))について扱います。\"Secure Multi-Party Computation\"は`shares`を使って暗号化を行う手法でSecureNNやSPDZと呼ばれるライブラリを使用します。詳細は[こちらのブログ](https://mortendahl.github.io/2017/09/19/private-image-analysis-with-mpc/)にてご確認ください。\n",
    "\n",
    "これらのプロトコルは、暗号化されたデータを使ってのコンピューテーションにおいて、目覚ましい成果を上げています。私たちはこれらのプロトコルを開発者が個々に実装することなく（場合によっては裏で動いている暗号技術を意識することもなく）使える仕組みを開発しています。それでは、始めましょう。\n",
    "\n",
    "## セットアップ\n",
    "\n",
    "このチュートリアルに必要な設定は次の通りです。データは手元にあると仮定してください。まず、手元にあるデータを使ってプライバシーに配慮したディープラーニングの手法を使ってモデルの定義とトレーニングを行います。次に何らかのデータを保持していて、モデルを使いたいユーザーと連携します。ここではモデルをトレーニングして公開する主体をサーバー（このケースではあなた）、モデルを使いたいユーザーをクライアントと呼ぶことにします。\n",
    "\n",
    "サーバー（あなた）はモデルを暗号化し、クライアントはデータを暗号化します。あなたとクライアントはどちらも暗号化されたモデルとデータを使ってデータの分類を行います。その後、推論結果を暗号化された状態のままクライアントへ戻します。その際、サーバーはデータについて一切知ることはありません。（入力データ、推論結果のどちらに関してもです。）\n",
    "\n",
    "理想的には`client`も`server`も`shares`をもつべきですが、今回のケースでは簡単のため、`shares`はBobとAliceという2つのリモートワーカーに分配します。もし、aliceはクラアントに、Bobはサーバーに属すと仮定すれば、正にサーバーとクライアントで`shares`を分け合っている状態です。\n",
    "\n",
    "この手法は、悪意の無い関係者間で、安全なコンピューテーションを実現できます。想定する環境は[many MPC frameworks](https://arxiv.org/pdf/1801.03239.pdf)にて標準化されています。\n",
    "ここで言う悪意の無い関係者とは、データがそのまま（閲覧可能な状態で）送られてきたら見てしまうかもしれないけれど、基本的には正直で悪意のない関係者（サーバー、クライアント）という意味です。\n",
    "\n",
    "**準備はと問いました。早速見ていきましょう**\n",
    "\n",
    "\n",
    "Author:\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ライブラリのインポート"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PySyft関連のライブラリをインポートします。何名かのリモートワーカー（ここでは `client`、 `bob`、それに `alice`の3名です）と暗号化技術のプリミティブを提供する`crypto_provider`を作成します。暗号化技術の基本データ型の詳細については[See our tutorial on SMPC for more details](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2009%20-%20Intro%20to%20Encrypted%20Programs.ipynb)を参照してください。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "hook = sy.TorchHook(torch) \n",
    "client = sy.VirtualWorker(hook, id=\"client\")\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "crypto_provider = sy.VirtualWorker(hook, id=\"crypto_provider\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ここで、トレーニングで使用するハイパーパラメータを定義します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 50\n",
    "        self.epochs = epochs\n",
    "        self.lr = 0.001\n",
    "        self.log_interval = 100\n",
    "\n",
    "args = Arguments()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### データの準備\n",
    "\n",
    "今回の設定では、サーバーがモデルと学習データを保持していると仮定しています。今回扱うデータはMNISTです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=args.batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、クライアントは、サーバーが提供するモデルを使って推論を行いたい、何らかのデータを持っていると仮定しているので、その準備をします。クライアントは`shares`を`alice` と `bob`に分割することでデータを暗号化します。\n",
    "\n",
    "> SMPCは整数で動く暗号化プロトコルを使います。PySyftのtensor拡張機能、`.fix_precision()`を使って不動小数から整数へ変換を行います。例えば、精度を2とすると、0.123は小数点第2位以下が丸められ、12になります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=args.test_batch_size, shuffle=True)\n",
    "\n",
    "private_test_loader = []\n",
    "for data, target in test_loader:\n",
    "    private_test_loader.append((\n",
    "        data.fix_precision().share(alice, bob, crypto_provider=crypto_provider),\n",
    "        target.fix_precision().share(alice, bob, crypto_provider=crypto_provider)\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### モデルの定義\n",
    "\"Feed Forward\"だけからなる基本的なモデルを定義します。このモデルはサーバーによって定義されます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 784)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### トレーニングループを定義\n",
    "\n",
    "この学習はサーバーのローカル環境下で行われます。ごく普通のPyTorchのトレーニングです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(args, model, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        output = F.log_softmax(output, dim=1)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size,\n",
    "                100. * batch_idx / len(train_loader), loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, train_loader, optimizer, epoch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            output = model(data)\n",
    "            output = F.log_softmax(output, dim=1)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability \n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test(args, model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルの学習が完了しました。準備OKです。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 暗号化されたデータとモデルを使っての評価"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "それでは、クライアントがクライアントのデータに対して推論を行えるよう、モデルをクライアントへ送りましょう。ですが、このモデルはデリケートな情報を含むため（トレーニングで時間と労力がかかっています！）、そのウェイトは非公開にしたいですよね。ここまでのチュートリアルで暗号化されたデータをリモートワーカーへ送ったように。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fix_precision().share(alice, bob, crypto_provider=crypto_provider)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "このテスト関数は暗号化されたデータを使ってのテストができる関数です。モデルのウェイト、入力データ、推論結果、そして正解データは全て暗号化されています。\n",
    "\n",
    "ですが、構文はピュアなPyTorchとほとんど同じですね。\n",
    "\n",
    "唯一サーバー側で複合化するのは最終的な精度のスコアだけです。スコアは推論結果を評価するために必要です。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, test_loader):\n",
    "    model.eval()\n",
    "    n_correct_priv = 0\n",
    "    n_total = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader[:n_test_batches]:\n",
    "            output = model(data)\n",
    "            pred = output.argmax(dim=1) \n",
    "            n_correct_priv += pred.eq(target.view_as(pred)).sum()\n",
    "            n_total += args.test_batch_size\n",
    "            # このテスト関数は暗号化されたデータでの評価を行えます。モデルのウェイト（パラメータ）、入力データ、推論結果、そ\n",
    "            # して正解ラベルと全てが暗号されています。\n",
    "            \n",
    "            # しかしながら、みなさんお気づきの通り、ごくごく一般的なPyTorchのテストスクリプトとほとんど同じです。\n",
    "            \n",
    "            # 唯一複合化しているのは、200アイテムのバッチ事に計算している精度確認のためのすこだけです。\n",
    "            # この数字を見ることで学習されたモデルの性能が良いのか悪いのか評価できます。\n",
    "            \n",
    "            n_correct = n_correct_priv.copy().get().float_precision().long().item()\n",
    "    \n",
    "            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(\n",
    "                n_correct, n_total,\n",
    "                100. * n_correct / n_total))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test(args, model, private_test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ジャジャーン！今回は暗号化されたデータを使っての推論処理に関する一通りのプロセスを学習しました。モデルのウェイトはクライアント側からは見えませんし、クライアントの入力データや推論結果もサーバー側からは見えません。\n",
    "\n",
    "パフォーマンスに関してですが、1枚の画像の分類に掛かる時間は**0.1秒以下**です。私のノートブック（2.7 GHz Intel Core i7, 16GB RAM）でざっと**33ミリ秒**といったところでしょうか。ですが、今回のチュートリアルでは全てのワーカーが実際には私のマシン上にいるため、通信に時間が掛かっていません。実際の環境でそれぞれのワーカーが別々の場所に存在する場合は、ワーカー間の通信速度に大きく影響を受けます。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "今回のチュートリアルでは、PyTorchとPySyftを使うことで、暗号化技術の専門家でなくても、機密データを使った、実践的、かつセキュアなディープラーニングが簡単に実行できることを学びました。\n",
    "\n",
    "本トピックについてはより多くの事例が追加されていく予定です。畳み込み層を使ったニューラルネットワークや、他のライブラリとのパフォーマンス比較や、外部にある機密データを扱ってのトレーニングなどなどです。お楽しみに。\n",
    "\n",
    "もし、このチュートリアルを気に入って、プライバシーに配慮した非中央集権的なAI技術や付随する（データやモデルの）サプライチェーンにご興味があって、プロジェクトに参加したいと思われるなら、以下の方法で可能です。\n",
    "### PySyftのGitHubレポジトリにスターをつける\n",
    "\n",
    "一番簡単に貢献できる方法はこのGitHubのレポジトリにスターを付けていただくことです。スターが増えると露出が増え、より多くのデベロッパーにこのクールな技術の事を知って貰えます。\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Slackに入る\n",
    "\n",
    "最新の開発状況のトラッキングする一番良い方法はSlackに入ることです。\n",
    "下記フォームから入る事ができます。\n",
    "[http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### コードプロジェクトに参加する\n",
    "\n",
    "コミュニティに貢献する一番良い方法はソースコードのコントリビューターになることです。PySyftのGitHubへアクセスしてIssueのページを開き、\"Projects\"で検索してみてください。参加し得るプロジェクトの状況を把握することができます。また、\"good first issue\"とマークされているIssueを探す事でミニプロジェクトを探すこともできます。\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### 寄付\n",
    "\n",
    "もし、ソースコードで貢献できるほどの時間は取れないけど、是非何かサポートしたいという場合は、寄付をしていただくことも可能です。寄附金の全ては、ハッカソンやミートアップの開催といった、コミュニティ運営経費として利用されます。\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
